import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch as t
from functions import deterministic_NeuralSort, torch_get_position_ctr

# Target device handed to the tensor constructors / `.to()` calls in this
# module: CUDA when PyTorch reports it as available, otherwise the CPU.
device = t.device("cuda" if t.cuda.is_available() else "cpu")

class Args:
    """Named view over a positional experiment-configuration sequence.

    The first eight entries of ``args`` are copied, in order, onto the
    attributes ``num_agent``, ``num_item``, ``distribution_type``,
    ``num_linear``, ``num_max``, ``num_sample_train``, ``num_sample_test``
    and ``seed_val``. Entries past index 7 are ignored.

    Raises whatever ``args[i]`` raises (e.g. ``IndexError`` for a sequence
    that is too short).
    """

    def __init__(self, args):
        field_names = (
            "num_agent",
            "num_item",
            "distribution_type",
            "num_linear",
            "num_max",
            "num_sample_train",
            "num_sample_test",
            "seed_val",
        )
        # Plain index lookups (not tuple unpacking) so that any object
        # supporting ``args[0] .. args[7]`` is accepted and longer inputs
        # are tolerated.
        for position, field_name in enumerate(field_names):
            setattr(self, field_name, args[position])
class Score_VCG(nn.Module):
    """Two ad slots + two organic slots, ranked by a learnable score.

    Each ad agent ``k`` owns a weight ``seller_w[k]`` and a bias
    ``seller_b[k]``; its input ``x`` is transformed to
    ``x * exp(seller_w) + seller_b`` before being combined with the CTR/CVR
    terms. Ads and organic items are (soft-)sorted separately with
    ``deterministic_NeuralSort``, the two best of each are merged into four
    positions, and click / conversion / payment statistics are derived.

    NOTE: ``seller_w``, ``seller_w2`` and ``seller_b`` are plain leaf tensors
    with ``requires_grad=True``, not ``nn.Parameter`` objects, so they are not
    reported by ``parameters()`` / ``state_dict()``. They are updated by hand
    in :meth:`seller_backward`.
    """

    def __init__(self, args):
        """Create the per-agent tensors.

        Args:
            args: object exposing ``num_agent`` (number of ad agents).
        """
        super(Score_VCG, self).__init__()
        self.args = args
        num_agents = self.args.num_agent

        # Draw order (w, w2, b) is kept so seeded runs stay reproducible.
        # NumPy produces float64, hence these tensors are float64 as well.
        self.seller_w = t.tensor(np.random.normal(size=(num_agents)) / 5, device=device, requires_grad=True)
        # seller_w2 is created but not read anywhere in this class.
        self.seller_w2 = t.tensor(np.random.normal(size=(num_agents)) / 5, device=device, requires_grad=True)
        self.seller_b = t.tensor(-np.random.rand(num_agents) * 1.0, device=device, requires_grad=True)

    @staticmethod
    def _permute(perm, values):
        """Reorder each row of ``values`` with the matching matrix in ``perm``.

        Both operands are cast to float32 before the batched mat-mul.
        """
        return (perm.float() @ values.float().unsqueeze(-1)).squeeze(-1)

    @staticmethod
    def _assign_positions(merged_desc, ads_sorted):
        """Split page positions 0..3 between an ad queue and an organic queue.

        For every sample, walk the four merged values from best to worst.
        A value found in that sample's sorted ad row goes to the ad queue,
        anything else to the organic queue; when the preferred queue already
        holds two positions the other queue receives it instead.

        Args:
            merged_desc: ``(batch, 4)`` merged values in descending order.
            ads_sorted: per-sample sorted ad values used for the membership test.

        Returns:
            ``(ads_queue, org_queue)``, two ``(batch, 2)`` float tensors holding
            position indices.
        """
        batch_size = merged_desc.shape[0]
        # NOTE(review): allocated without a device argument (i.e. on the CPU),
        # exactly as before; confirm torch_get_position_ctr copes with that
        # when `device` is CUDA.
        ads_queue = t.zeros([batch_size, 2])
        org_queue = t.zeros([batch_size, 2])

        for i in range(batch_size):
            ads_pos, org_pos = 0, 0
            for j in range(4):
                # Exact float membership against the whole sorted ad row.
                if merged_desc[i, j] in ads_sorted[i]:
                    # The last cell can only hold a position >= 1 once it is
                    # written, so "== 0" means "ad queue still has room".
                    to_ads = bool(ads_queue[i, -1] == 0)
                else:
                    to_ads = not bool(org_queue[i, -1] == 0)

                if to_ads:
                    ads_queue[i, ads_pos] = j
                    ads_pos += 1
                else:
                    org_queue[i, org_pos] = j
                    org_pos += 1
        return ads_queue, org_queue

    def forward(self, x, ctr_ads, ctr_og, cvr_ads, cvr_og, alpha, beta):
        """Run the allocation and return its aggregate statistics.

        Args:
            x: per-sample, per-agent inputs with a ``.shape`` attribute
                (presumably bids of shape ``(batch, num_agent)``; confirm
                with the caller).
            ctr_ads, ctr_og: click-through inputs for ads / organic items.
            cvr_ads, cvr_og: conversion inputs for ads / organic items.
            alpha, beta: weights of the CTR and CVR terms in the score.

        Returns:
            Tuple ``(revenue, click, cost, cvr)`` of scalar tensors where
            ``revenue = cost + alpha * click - 10 * relu(0.1 - cvr)``.

        Requires at least three ad agents and two organic items, because the
        third-ranked ad value and the top two organic values are read.
        """
        batch_size = x.shape[0]

        # Convert inputs to tensors (x keeps its dtype, the rest become float32).
        x = t.tensor(x, device=device)
        ctr_ads = t.tensor(ctr_ads, device=device).float()
        ctr_og = t.tensor(ctr_og, device=device).float()
        cvr_ads = t.tensor(cvr_ads, device=device).float()
        cvr_og = t.tensor(cvr_og, device=device).float()

        # Tile the per-agent weight / bias over the batch.
        w_copy = self.seller_w.repeat(batch_size, 1)
        b_copy = self.seller_b.repeat(batch_size, 1)

        # Scores: ads use the transformed input, organic items only CTR/CVR.
        vv = x * t.exp(w_copy) + b_copy
        value_ads = vv * ctr_ads + alpha * ctr_ads + beta * cvr_ads
        value_org = alpha * ctr_og + beta * cvr_og

        # Sort ads and organic items separately.
        perm_ads = deterministic_NeuralSort(value_ads.unsqueeze(-1), tau=0.001)
        perm_org = deterministic_NeuralSort(value_org.unsqueeze(-1), tau=0.001)
        value_ads_sorted = self._permute(perm_ads, value_ads)
        value_org_sorted = self._permute(perm_org, value_org)

        # Merge the two best of each kind and order them from high to low.
        value_sorted = t.hstack((value_ads_sorted[:, :2], value_org_sorted[:, :2]))
        value_sorted2 = -t.sort(-value_sorted).values
        ads_queue, org_queue = self._assign_positions(value_sorted2, value_ads_sorted)

        # CTR / CVR in the same order as the sorted values.
        ctr_sorted_ads = self._permute(perm_ads, ctr_ads)
        ctr_sorted_org = self._permute(perm_org, ctr_og)
        cvr_sorted_ads = self._permute(perm_ads, cvr_ads)
        cvr_sorted_org = self._permute(perm_org, cvr_og)

        # Position factors, looked up once and reused below (this assumes
        # torch_get_position_ctr is a pure function of its arguments).
        ads_pos_ctr = torch_get_position_ctr(ads_queue[:, :2], 2)
        org_pos_ctr = torch_get_position_ctr(org_queue[:, :2], 2)
        first_pos_ctr = torch_get_position_ctr(ads_queue[:, 0], 1)
        second_pos_ctr = torch_get_position_ctr(ads_queue[:, 1], 1)

        # Calculate clicks and CVR
        click = t.mean(t.sum(ctr_sorted_ads[:, :2] * ads_pos_ctr, axis=1) +
                       t.sum(ctr_sorted_org[:, :2] * org_pos_ctr, axis=1))
        cvr = t.mean(t.sum(cvr_sorted_ads[:, :2], 1) / 4 + t.sum(cvr_sorted_org[:, :2], 1) / 4)

        # Calculate payments for the two ad slots from the values of the
        # next-ranked ads, minus the winner's own CVR / CTR terms, divided by
        # the winner's position-weighted CTR.
        payment = t.zeros([batch_size, 2]).to(device)
        payment[:, 0] = ((value_ads_sorted[:, 1] * first_pos_ctr +
                          value_ads_sorted[:, 2] * second_pos_ctr -
                          value_ads_sorted[:, 1] * second_pos_ctr) -
                         beta * cvr_sorted_ads[:, 0] -
                         alpha * ctr_sorted_ads[:, 0] * first_pos_ctr) / (ctr_sorted_ads[:, 0] * first_pos_ctr)
        payment[:, 1] = ((value_ads_sorted[:, 2] * second_pos_ctr) -
                         beta * cvr_sorted_ads[:, 1] -
                         alpha * ctr_sorted_ads[:, 1] * second_pos_ctr) / (ctr_sorted_ads[:, 1] * second_pos_ctr)

        # Calculate cost: invert the score transform for the two paying ads.
        # Only the first two sorted columns are needed, so slice the weights
        # instead of zero-padding `payment` to a fixed width of four (which
        # only broadcast correctly when num_agent == 4).
        w_spa = self._permute(perm_ads, w_copy)[:, :2]
        b_spa = self._permute(perm_ads, b_copy)[:, :2]
        p = F.relu(t.mul(payment - b_spa, t.reciprocal(t.exp(w_spa))))
        cost = t.mean(t.sum(p * ctr_sorted_ads[:, :2] * ads_pos_ctr, axis=1))

        # Calculate revenue; the last term penalises a mean CVR below 0.1.
        revenue = cost + alpha * click - 10 * F.relu(0.1 - cvr)

        return revenue, click, cost, cvr

    def seller_backward(self, args, x, ctr_ads, ctr_og, cvr_ads, cvr_og, alpha, beta, lr=0.01):
        """Take one gradient-ascent step on revenue for ``seller_w`` / ``seller_b``.

        Args:
            args: unused; kept for call compatibility.
            x, ctr_ads, ctr_og, cvr_ads, cvr_og, alpha, beta: forwarded
                unchanged to :meth:`forward`.
            lr: step size (defaults to the previously hard-coded 0.01).
        """
        revenue = self.forward(x, ctr_ads, ctr_og, cvr_ads, cvr_og, alpha, beta)[0]
        # Minimising the negated revenue maximises revenue.
        (-revenue).backward()

        # Manual SGD step; no_grad keeps the in-place update off the graph.
        with t.no_grad():
            for param in (self.seller_w, self.seller_b):
                param.sub_(lr * param.grad)
                param.grad.zero_()

class Hyper_VCG(nn.Module):
    """Four-slot allocation over the merged ad + organic list.

    Unlike a module that owns its weights, the per-agent weight ``w1`` and
    bias ``b1`` are passed into :meth:`forward` by the caller. Ads and organic
    items are scored, concatenated, (soft-)sorted together with
    ``deterministic_NeuralSort`` and the top four entries are charged using
    the fixed position factors ``[1.0, 0.8, 0.6, 0.5]``.
    """

    def __init__(self, args):
        """Store ``args``; ``args.num_agent`` is read in :meth:`forward`."""
        super(Hyper_VCG, self).__init__()
        self.args = args

    def forward(self, inputs, ctr_ads, ctr_og, cvr_ads, cvr_og, alpha, beta, w1, b1):
        """Run the allocation and return its aggregate statistics.

        Args:
            inputs: per-sample, per-agent inputs with a ``.shape`` attribute
                (presumably bids of shape ``(batch, num_agent)``; confirm
                with the caller).
            ctr_ads, ctr_og: click-through inputs for ads / organic items.
            cvr_ads, cvr_og: conversion inputs for ads / organic items.
            alpha, beta: weights of the CTR and CVR terms in the score.
            w1, b1: per-agent weight and bias tensors; each is tiled over the
                batch with ``repeat(batch_size, 1)``.

        Returns:
            Tuple ``(revenue, cost, click, perct, cvr)``. ``perct`` is a Python
            float, the others are scalar tensors, and
            ``revenue = cost + alpha * click - 10 * relu(0.1 - cvr)``.

        The merged list must contain at least five entries because the value
        ranked fifth enters the price of the first slot.
        """
        pos_ratio = t.tensor([1., 0.8, 0.6, 0.5]).float().to(device)
        num_slots = 4
        num_agents = self.args.num_agent
        batch_size = inputs.shape[0]

        # Convert inputs to float32 tensors on the target device.
        ctr_ads = t.tensor(ctr_ads).float().to(device)
        ctr_og = t.tensor(ctr_og).float().to(device)
        cvr_ads = t.tensor(cvr_ads).float().to(device)
        cvr_og = t.tensor(cvr_og).float().to(device)
        x = t.tensor(inputs).float().to(device)

        # Tile the per-agent weight / bias over the batch.
        w_rows = w1.repeat(batch_size, 1)
        b_rows = b1.repeat(batch_size, 1)

        # Scores: ads use the transformed input, organic items only CTR/CVR.
        vv = x * t.exp(w_rows) + b_rows
        value_ads = vv * ctr_ads + alpha * ctr_ads + beta * cvr_ads
        value_org = alpha * ctr_og + beta * cvr_og
        value_combined = t.hstack((value_ads, value_org))

        # One joint sort of ads and organic items.
        perm = deterministic_NeuralSort(value_combined.unsqueeze(-1), tau=0.01)

        def ranked(merged):
            """Reorder the columns of ``merged`` with the sort matrices."""
            return (perm @ merged.unsqueeze(-1)).squeeze(-1)

        value_sorted = ranked(value_combined)
        # Organic items are given an input of zero.
        bid_sorted = ranked(t.hstack((x, t.zeros(x.shape).to(device))))
        ctr_sorted = ranked(t.hstack((ctr_ads, ctr_og)))
        cvr_sorted = ranked(t.hstack((cvr_ads, cvr_og)))

        # Calculate clicks and CVR over the four displayed entries.
        click = t.mean(t.sum(ctr_sorted[:, :num_slots] * pos_ratio, 1))
        cvr = t.mean(t.mean(cvr_sorted[:, :num_slots], 1))

        # Per-slot payment for slot k:
        #   (value of lower-ranked entries when each moves up one position)
        # - (value of lower-ranked entries at their current positions)
        # - the winner's own CVR / CTR terms,
        # divided by the winner's position-weighted CTR.
        payment = t.zeros([batch_size, num_slots]).to(device)
        for k in range(num_slots):
            moved_up = t.sum(value_sorted[:, k + 1:num_slots + 1] * pos_ratio[k:], 1)
            # Empty slice (sum 0) for the last slot.
            in_place = t.sum(value_sorted[:, k + 1:num_slots] * pos_ratio[k + 1:], 1)
            own_terms = beta * cvr_sorted[:, k] + alpha * ctr_sorted[:, k] * pos_ratio[k]
            payment[:, k] = (moved_up - in_place - own_terms) / (ctr_sorted[:, k] * pos_ratio[k])

        # Weight / bias of whichever entry landed in each slot. Organic
        # entries use the constants w = 1 and b = 0.1.
        org_w = t.ones([batch_size, num_agents]).to(device)
        org_b = 0.1 * t.ones([batch_size, num_agents]).to(device)
        w_top = ranked(t.hstack((w_rows, org_w)))[:, 0:num_slots]
        b_top = ranked(t.hstack((b_rows, org_b)))[:, 0:num_slots]

        # Invert the score transform and clamp negative prices to zero.
        # (A plain relu: the former second positional argument was being
        # interpreted as the `inplace` flag.)
        price = F.relu(t.mul(payment - b_top, t.reciprocal(t.exp(w_top))))

        # Share of (near-)zero inputs among the displayed entries.
        # NOTE(review): the count spans 4 slots per sample while the
        # denominator uses num_agents; the two agree only when
        # num_agent == 4 — confirm the intended normalisation.
        count = (bid_sorted[:, :num_slots] < 0.001).sum().item()
        perct = count / (batch_size * num_agents)

        # Calculate final cost: position-weighted CTR times price, added
        # column by column to keep the original accumulation order.
        charged = pos_ratio * ctr_sorted[:, :num_slots] * price
        cost = t.mean(charged[:, 0] + charged[:, 1] + charged[:, 2] + charged[:, 3])

        # Calculate revenue; the last term penalises a mean CVR below 0.1.
        revenue = cost + alpha * click - 10 * F.relu(0.1 - cvr)

        return revenue, cost, click, perct, cvr
